{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp tabular.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.all import *\n",
    "from fastai.tabular.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular data\n",
    "\n",
    "> Helper functions to get data into a `DataLoaders` for the tabular application, and the higher-level class `TabularDataLoaders`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main class to get your data ready for model training is `TabularDataLoaders` and its factory methods. Check out the [tabular tutorial](https://docs.fast.ai/tutorial.tabular.html) for examples of use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TabularDataLoaders -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TabularDataLoaders(DataLoaders):\n",
    "    \"Basic wrapper around several `DataLoader`s with factory methods for tabular data\"\n",
    "    @classmethod\n",
    "    @delegates(Tabular.dataloaders, but=[\"dl_type\", \"dl_kwargs\"])\n",
    "    def from_df(cls, \n",
    "        df:pd.DataFrame,\n",
    "        path:str|Path='.', # Location of `df`, defaults to current working directory\n",
    "        procs:list=None, # List of `TabularProc`s\n",
    "        cat_names:list=None, # Column names pertaining to categorical variables\n",
    "        cont_names:list=None, # Column names pertaining to continuous variables\n",
    "        y_names:list=None, # Names of the dependent variables\n",
    "        y_block:TransformBlock=None, # `TransformBlock` to use for the target(s)\n",
    "        valid_idx:list=None, # List of indices to use for the validation set, defaults to a random split\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"Create `TabularDataLoaders` from `df` in `path` using `procs`\"\n",
    "        if cat_names is None: cat_names = []\n",
    "        if cont_names is None:\n",
    "            # Every column that is neither categorical nor a target is treated as continuous.\n",
    "            # Walk `df.columns` instead of subtracting sets: iterating a set of strings is not\n",
    "            # stable across interpreter runs, which made the continuous column order irreproducible.\n",
    "            # `dict.fromkeys` drops repeated column names while keeping first-seen order.\n",
    "            taken = set(L(cat_names)) | set(L(y_names))\n",
    "            cont_names = list(dict.fromkeys(c for c in df.columns if c not in taken))\n",
    "        # Random split unless explicit validation indices were given\n",
    "        splits = RandomSplitter()(df) if valid_idx is None else IndexSplitter(valid_idx)(df)\n",
    "        to = TabularPandas(df, procs, cat_names, cont_names, y_names, splits=splits, y_block=y_block)\n",
    "        return to.dataloaders(path=path, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_csv(cls, \n",
    "        csv:str|Path|io.BufferedReader, # A csv of training data\n",
    "        skipinitialspace:bool=True, # Skip spaces after delimiter\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"Create `TabularDataLoaders` from `csv` file in `path` using `procs`\"\n",
    "        # Read the file with pandas, then hand everything else over to `from_df`\n",
    "        df = pd.read_csv(csv, skipinitialspace=skipinitialspace)\n",
    "        return cls.from_df(df, **kwargs)\n",
    "\n",
    "    @delegates(TabDataLoader.__init__)\n",
    "    def test_dl(self, \n",
    "        test_items, # Items to create new test `TabDataLoader` formatted the same as the training data\n",
    "        rm_type_tfms=None, # Number of `Transform`s to be removed from `procs`\n",
    "        process:bool=True, # Apply validation `TabularProc`s to `test_items` immediately\n",
    "        inplace:bool=False, # Keep separate copy of original `test_items` in memory if `False`\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"Create test `TabDataLoader` from `test_items` using validation `procs`\"\n",
    "        # NOTE(review): `rm_type_tfms` is never read in this method; presumably it is only\n",
    "        # accepted for signature compatibility with the parent `test_dl` - confirm.\n",
    "        to = self.train_ds.new(test_items, inplace=inplace)\n",
    "        if process: to.process()\n",
    "        # Clone the validation loader around the new items; `kwargs` override its settings\n",
    "        return self.valid.new(to, **kwargs)\n",
    "\n",
    "# Register this class as the `_dbunch_type` of `Tabular`\n",
    "Tabular._dbunch_type = TabularDataLoaders\n",
    "# Applied after the class body because `from_df` must already exist as a class attribute:\n",
    "# lets `from_csv` expose the keyword arguments of `from_df` in its signature\n",
    "TabularDataLoaders.from_csv = delegates(to=TabularDataLoaders.from_df)(TabularDataLoaders.from_csv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class should not be used directly; one of the factory methods should be preferred instead. All those factory methods accept as arguments:\n",
    "\n",
    "- `cat_names`: the names of the categorical variables\n",
    "- `cont_names`: the names of the continuous variables\n",
    "- `y_names`: the names of the dependent variables\n",
    "- `y_block`: the `TransformBlock` to use for the target\n",
    "- `valid_idx`: the indices to use for the validation set (defaults to a random split otherwise)\n",
    "- `bs`: the batch size\n",
    "- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n",
    "- `shuffle_train`: whether to shuffle the training `DataLoader` (deprecated, use `shuffle`)\n",
    "- `n`: overrides the number of elements in the dataset\n",
    "- `device`: the PyTorch device to use (defaults to `default_device()`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/data.py#L19){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularDataLoaders.from_df\n",
       "\n",
       ">      TabularDataLoaders.from_df (df:pd.DataFrame, path:str|Path='.',\n",
       ">                                  procs:list=None, cat_names:list=None,\n",
       ">                                  cont_names:list=None, y_names:list=None,\n",
       ">                                  y_block:TransformBlock=None,\n",
       ">                                  valid_idx:list=None, bs:int=64,\n",
       ">                                  shuffle_train:bool=None, shuffle:bool=True,\n",
       ">                                  val_shuffle:bool=False, n:int=None,\n",
       ">                                  device:torch.device=None,\n",
       ">                                  drop_last:bool=None, val_bs:int=None)\n",
       "\n",
       "*Create `TabularDataLoaders` from `df` in `path` using `procs`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| df | pd.DataFrame |  |  |\n",
       "| path | str \\| Path | . | Location of `df`, defaults to current working directory |\n",
       "| procs | list | None | List of `TabularProc`s |\n",
       "| cat_names | list | None | Column names pertaining to categorical variables |\n",
       "| cont_names | list | None | Column names pertaining to continuous variables |\n",
       "| y_names | list | None | Names of the dependent variables |\n",
       "| y_block | TransformBlock | None | `TransformBlock` to use for the target(s) |\n",
       "| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |\n",
       "| bs | int | 64 | Batch size |\n",
       "| shuffle_train | bool | None | (Deprecated, use `shuffle`) Shuffle training `DataLoader` |\n",
       "| shuffle | bool | True | Shuffle training `DataLoader` |\n",
       "| val_shuffle | bool | False | Shuffle validation `DataLoader` |\n",
       "| n | int | None | Size of `Datasets` used to create `DataLoader` |\n",
       "| device | device | None | Device to put `DataLoaders` |\n",
       "| drop_last | bool | None | Drop last incomplete batch, defaults to `shuffle` |\n",
       "| val_bs | int | None | Validation batch size, defaults to `bs` |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/data.py#L19){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularDataLoaders.from_df\n",
       "\n",
       ">      TabularDataLoaders.from_df (df:pd.DataFrame, path:str|Path='.',\n",
       ">                                  procs:list=None, cat_names:list=None,\n",
       ">                                  cont_names:list=None, y_names:list=None,\n",
       ">                                  y_block:TransformBlock=None,\n",
       ">                                  valid_idx:list=None, bs:int=64,\n",
       ">                                  shuffle_train:bool=None, shuffle:bool=True,\n",
       ">                                  val_shuffle:bool=False, n:int=None,\n",
       ">                                  device:torch.device=None,\n",
       ">                                  drop_last:bool=None, val_bs:int=None)\n",
       "\n",
       "*Create `TabularDataLoaders` from `df` in `path` using `procs`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| df | pd.DataFrame |  |  |\n",
       "| path | str \\| Path | . | Location of `df`, defaults to current working directory |\n",
       "| procs | list | None | List of `TabularProc`s |\n",
       "| cat_names | list | None | Column names pertaining to categorical variables |\n",
       "| cont_names | list | None | Column names pertaining to continuous variables |\n",
       "| y_names | list | None | Names of the dependent variables |\n",
       "| y_block | TransformBlock | None | `TransformBlock` to use for the target(s) |\n",
       "| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |\n",
       "| bs | int | 64 | Batch size |\n",
       "| shuffle_train | bool | None | (Deprecated, use `shuffle`) Shuffle training `DataLoader` |\n",
       "| shuffle | bool | True | Shuffle training `DataLoader` |\n",
       "| val_shuffle | bool | False | Shuffle validation `DataLoader` |\n",
       "| n | int | None | Size of `Datasets` used to create `DataLoader` |\n",
       "| device | device | None | Device to put `DataLoaders` |\n",
       "| drop_last | bool | None | Drop last incomplete batch, defaults to `shuffle` |\n",
       "| val_bs | int | None | Validation batch size, defaults to `bs` |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TabularDataLoaders.from_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a look at an example with the adult dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>salary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>49</td>\n",
       "      <td>Private</td>\n",
       "      <td>101320</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Wife</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>1902</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;=50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>236746</td>\n",
       "      <td>Masters</td>\n",
       "      <td>14.0</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>10520</td>\n",
       "      <td>0</td>\n",
       "      <td>45</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;=50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>96185</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Unmarried</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>32</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>38</td>\n",
       "      <td>Self-emp-inc</td>\n",
       "      <td>112847</td>\n",
       "      <td>Prof-school</td>\n",
       "      <td>15.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Asian-Pac-Islander</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;=50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>42</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>82297</td>\n",
       "      <td>7th-8th</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age         workclass  fnlwgt    education  education-num  \\\n",
       "0   49           Private  101320   Assoc-acdm           12.0   \n",
       "1   44           Private  236746      Masters           14.0   \n",
       "2   38           Private   96185      HS-grad            NaN   \n",
       "3   38      Self-emp-inc  112847  Prof-school           15.0   \n",
       "4   42  Self-emp-not-inc   82297      7th-8th            NaN   \n",
       "\n",
       "       marital-status       occupation   relationship                race  \\\n",
       "0  Married-civ-spouse              NaN           Wife               White   \n",
       "1            Divorced  Exec-managerial  Not-in-family               White   \n",
       "2            Divorced              NaN      Unmarried               Black   \n",
       "3  Married-civ-spouse   Prof-specialty        Husband  Asian-Pac-Islander   \n",
       "4  Married-civ-spouse    Other-service           Wife               Black   \n",
       "\n",
       "      sex  capital-gain  capital-loss  hours-per-week native-country salary  \n",
       "0  Female             0          1902              40  United-States  >=50k  \n",
       "1    Male         10520             0              45  United-States  >=50k  \n",
       "2  Female             0             0              32  United-States   <50k  \n",
       "3    Male             0             0              40  United-States  >=50k  \n",
       "4  Female             0             0              50  United-States   <50k  "
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv', skipinitialspace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [Categorify, FillMissing, Normalize]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, \n",
    "                                 y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>education-num_na</th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education-num</th>\n",
       "      <th>salary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>24.0</td>\n",
       "      <td>121312.998272</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>19.0</td>\n",
       "      <td>198320.000325</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Private</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Sales</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>66.0</td>\n",
       "      <td>169803.999308</td>\n",
       "      <td>13.0</td>\n",
       "      <td>&gt;=50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Unmarried</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>40.0</td>\n",
       "      <td>799280.980929</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Local-gov</td>\n",
       "      <td>10th</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>18.0</td>\n",
       "      <td>55658.003629</td>\n",
       "      <td>6.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Other-relative</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>30.0</td>\n",
       "      <td>375827.003847</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Private</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>20.0</td>\n",
       "      <td>173723.999335</td>\n",
       "      <td>10.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>?</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>21.0</td>\n",
       "      <td>107800.997986</td>\n",
       "      <td>10.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>19.0</td>\n",
       "      <td>263338.000072</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Private</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Tech-support</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>35.0</td>\n",
       "      <td>194590.999986</td>\n",
       "      <td>10.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dls.show_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/data.py#L38){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularDataLoaders.from_csv\n",
       "\n",
       ">      TabularDataLoaders.from_csv (csv:str|Path|io.BufferedReader,\n",
       ">                                   skipinitialspace:bool=True,\n",
       ">                                   path:str|Path='.', procs:list=None,\n",
       ">                                   cat_names:list=None, cont_names:list=None,\n",
       ">                                   y_names:list=None,\n",
       ">                                   y_block:TransformBlock=None,\n",
       ">                                   valid_idx:list=None, bs:int=64,\n",
       ">                                   shuffle_train:bool=None, shuffle:bool=True,\n",
       ">                                   val_shuffle:bool=False, n:int=None,\n",
       ">                                   device:torch.device=None,\n",
       ">                                   drop_last:bool=None, val_bs:int=None)\n",
       "\n",
       "*Create `TabularDataLoaders` from `csv` file in `path` using `procs`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| csv | str \\| Path \\| io.BufferedReader |  | A csv of training data |\n",
       "| skipinitialspace | bool | True | Skip spaces after delimiter |\n",
       "| path | str \\| Path | . | Location of `df`, defaults to current working directory |\n",
       "| procs | list | None | List of `TabularProc`s |\n",
       "| cat_names | list | None | Column names pertaining to categorical variables |\n",
       "| cont_names | list | None | Column names pertaining to continuous variables |\n",
       "| y_names | list | None | Names of the dependent variables |\n",
       "| y_block | TransformBlock | None | `TransformBlock` to use for the target(s) |\n",
       "| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |\n",
       "| bs | int | 64 | Batch size |\n",
       "| shuffle_train | bool | None | (Deprecated, use `shuffle`) Shuffle training `DataLoader` |\n",
       "| shuffle | bool | True | Shuffle training `DataLoader` |\n",
       "| val_shuffle | bool | False | Shuffle validation `DataLoader` |\n",
       "| n | int | None | Size of `Datasets` used to create `DataLoader` |\n",
       "| device | torch.device | None | Device to put `DataLoaders` |\n",
       "| drop_last | bool | None | Drop last incomplete batch, defaults to `shuffle` |\n",
       "| val_bs | int | None | Validation batch size, defaults to `bs` |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/data.py#L38){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularDataLoaders.from_csv\n",
       "\n",
       ">      TabularDataLoaders.from_csv (csv:str|Path|io.BufferedReader,\n",
       ">                                   skipinitialspace:bool=True,\n",
       ">                                   path:str|Path='.', procs:list=None,\n",
       ">                                   cat_names:list=None, cont_names:list=None,\n",
       ">                                   y_names:list=None,\n",
       ">                                   y_block:TransformBlock=None,\n",
       ">                                   valid_idx:list=None, bs:int=64,\n",
       ">                                   shuffle_train:bool=None, shuffle:bool=True,\n",
       ">                                   val_shuffle:bool=False, n:int=None,\n",
       ">                                   device:torch.device=None,\n",
       ">                                   drop_last:bool=None, val_bs:int=None)\n",
       "\n",
       "*Create `TabularDataLoaders` from `csv` file in `path` using `procs`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| csv | str \\| Path \\| io.BufferedReader |  | A csv of training data |\n",
       "| skipinitialspace | bool | True | Skip spaces after delimiter |\n",
       "| path | str \\| Path | . | Location of `df`, defaults to current working directory |\n",
       "| procs | list | None | List of `TabularProc`s |\n",
       "| cat_names | list | None | Column names pertaining to categorical variables |\n",
       "| cont_names | list | None | Column names pertaining to continuous variables |\n",
       "| y_names | list | None | Names of the dependent variables |\n",
       "| y_block | TransformBlock | None | `TransformBlock` to use for the target(s) |\n",
       "| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |\n",
       "| bs | int | 64 | Batch size |\n",
       "| shuffle_train | bool | None | (Deprecated, use `shuffle`) Shuffle training `DataLoader` |\n",
       "| shuffle | bool | True | Shuffle training `DataLoader` |\n",
       "| val_shuffle | bool | False | Shuffle validation `DataLoader` |\n",
       "| n | int | None | Size of `Datasets` used to create `DataLoader` |\n",
       "| device | torch.device | None | Device to put `DataLoaders` |\n",
       "| drop_last | bool | None | Drop last incomplete batch, defaults to `shuffle` |\n",
       "| val_bs | int | None | Validation batch size, defaults to `bs` |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TabularDataLoaders.from_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [Categorify, FillMissing, Normalize]\n",
    "dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names, \n",
    "                                  y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/data.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularDataLoaders.test_dl\n",
       "\n",
       ">      TabularDataLoaders.test_dl (test_items, rm_type_tfms=None,\n",
       ">                                  process:bool=True, inplace:bool=False, bs=16,\n",
       ">                                  shuffle=False, after_batch=None,\n",
       ">                                  num_workers=0, verbose:bool=False,\n",
       ">                                  do_setup:bool=True, pin_memory=False,\n",
       ">                                  timeout=0, batch_size=None, drop_last=False,\n",
       ">                                  indexed=None, n=None, device=None,\n",
       ">                                  persistent_workers=False,\n",
       ">                                  pin_memory_device='', wif=None,\n",
       ">                                  before_iter=None, after_item=None,\n",
       ">                                  before_batch=None, after_iter=None,\n",
       ">                                  create_batches=None, create_item=None,\n",
       ">                                  create_batch=None, retain=None,\n",
       ">                                  get_idxs=None, sample=None, shuffle_fn=None,\n",
       ">                                  do_batch=None)\n",
       "\n",
       "*Create test `TabDataLoader` from `test_items` using validation `procs`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| test_items |  |  | Items to create new test `TabDataLoader` formatted the same as the training data |\n",
       "| rm_type_tfms | NoneType | None | Number of `Transform`s to be removed from `procs` |\n",
       "| process | bool | True | Apply validation `TabularProc`s to `test_items` immediately |\n",
       "| inplace | bool | False | Keep separate copy of original `test_items` in memory if `False` |\n",
       "| bs | int | 64 | Size of batch |\n",
       "| shuffle | bool | False | Whether to shuffle data |\n",
       "| after_batch | NoneType | None |  |\n",
       "| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |\n",
       "| verbose | bool | False | Whether to print verbose logs |\n",
       "| do_setup | bool | True | Whether to run `setup()` for batch transform(s) |\n",
       "| pin_memory | bool | False |  |\n",
       "| timeout | int | 0 |  |\n",
       "| batch_size | NoneType | None |  |\n",
       "| drop_last | bool | False |  |\n",
       "| indexed | NoneType | None |  |\n",
       "| n | NoneType | None |  |\n",
       "| device | NoneType | None |  |\n",
       "| persistent_workers | bool | False |  |\n",
       "| pin_memory_device | str |  |  |\n",
       "| wif | NoneType | None |  |\n",
       "| before_iter | NoneType | None |  |\n",
       "| after_item | NoneType | None |  |\n",
       "| before_batch | NoneType | None |  |\n",
       "| after_iter | NoneType | None |  |\n",
       "| create_batches | NoneType | None |  |\n",
       "| create_item | NoneType | None |  |\n",
       "| create_batch | NoneType | None |  |\n",
       "| retain | NoneType | None |  |\n",
       "| get_idxs | NoneType | None |  |\n",
       "| sample | NoneType | None |  |\n",
       "| shuffle_fn | NoneType | None |  |\n",
       "| do_batch | NoneType | None |  |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/data.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularDataLoaders.test_dl\n",
       "\n",
       ">      TabularDataLoaders.test_dl (test_items, rm_type_tfms=None,\n",
       ">                                  process:bool=True, inplace:bool=False, bs=16,\n",
       ">                                  shuffle=False, after_batch=None,\n",
       ">                                  num_workers=0, verbose:bool=False,\n",
       ">                                  do_setup:bool=True, pin_memory=False,\n",
       ">                                  timeout=0, batch_size=None, drop_last=False,\n",
       ">                                  indexed=None, n=None, device=None,\n",
       ">                                  persistent_workers=False,\n",
       ">                                  pin_memory_device='', wif=None,\n",
       ">                                  before_iter=None, after_item=None,\n",
       ">                                  before_batch=None, after_iter=None,\n",
       ">                                  create_batches=None, create_item=None,\n",
       ">                                  create_batch=None, retain=None,\n",
       ">                                  get_idxs=None, sample=None, shuffle_fn=None,\n",
       ">                                  do_batch=None)\n",
       "\n",
       "*Create test `TabDataLoader` from `test_items` using validation `procs`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| test_items |  |  | Items to create new test `TabDataLoader` formatted the same as the training data |\n",
       "| rm_type_tfms | NoneType | None | Number of `Transform`s to be removed from `procs` |\n",
       "| process | bool | True | Apply validation `TabularProc`s to `test_items` immediately |\n",
       "| inplace | bool | False | Keep separate copy of original `test_items` in memory if `False` |\n",
       "| bs | int | 64 | Size of batch |\n",
       "| shuffle | bool | False | Whether to shuffle data |\n",
       "| after_batch | NoneType | None |  |\n",
       "| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |\n",
       "| verbose | bool | False | Whether to print verbose logs |\n",
       "| do_setup | bool | True | Whether to run `setup()` for batch transform(s) |\n",
       "| pin_memory | bool | False |  |\n",
       "| timeout | int | 0 |  |\n",
       "| batch_size | NoneType | None |  |\n",
       "| drop_last | bool | False |  |\n",
       "| indexed | NoneType | None |  |\n",
       "| n | NoneType | None |  |\n",
       "| device | NoneType | None |  |\n",
       "| persistent_workers | bool | False |  |\n",
       "| pin_memory_device | str |  |  |\n",
       "| wif | NoneType | None |  |\n",
       "| before_iter | NoneType | None |  |\n",
       "| after_item | NoneType | None |  |\n",
       "| before_batch | NoneType | None |  |\n",
       "| after_iter | NoneType | None |  |\n",
       "| create_batches | NoneType | None |  |\n",
       "| create_item | NoneType | None |  |\n",
       "| create_batch | NoneType | None |  |\n",
       "| retain | NoneType | None |  |\n",
       "| get_idxs | NoneType | None |  |\n",
       "| sample | NoneType | None |  |\n",
       "| shuffle_fn | NoneType | None |  |\n",
       "| do_batch | NoneType | None |  |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TabularDataLoaders.test_dl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "External structured data files can contain unexpected spaces, e.g. after a comma. We can see this in the first row of adult.csv: `\"49, Private,101320, ...\"`. Trimming is often needed. Pandas has a convenient parameter, `skipinitialspace`, that is exposed by `TabularDataLoaders.from_csv()`. Otherwise, category labels used later for inference, such as `workclass`:`Private`, will be wrongly categorized as *0* or `\"#na#\"` if the training label was read as `\" Private\"`. Let's test this feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One inference row whose labels carry no leading space (e.g. 'Private', not ' Private').\n",
    "test_data = {\n",
    "    'age': [49],\n",
    "    'workclass': ['Private'],\n",
    "    'fnlwgt': [101320],\n",
    "    'education': ['Assoc-acdm'],\n",
    "    'education-num': [12.0],\n",
    "    'marital-status': ['Married-civ-spouse'],\n",
    "    'occupation': [''],\n",
    "    'relationship': ['Wife'],\n",
    "    'race': ['White'],\n",
    "}\n",
    "# Named `test_df` rather than `input` so the builtin `input()` is not shadowed in the notebook namespace.\n",
    "test_df = pd.DataFrame(test_data)\n",
    "tdl = dls.test_dl(test_df)\n",
    "\n",
    "# 'Private' must not be encoded as 0 (the `#na#` code), which is what happens when the training labels kept a leading space.\n",
    "test_ne(0, tdl.dataset.iloc[0]['workclass'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
